package sqlparser

import (
	"db-sharding/system/log4go"
	"github.com/cznic/mathutil"
	"github.com/emirpasic/gods/sets/hashset"
	"regexp"
	"strings"
)

// Markers of the routing-hint comment, e.g. /** lc=2007*/.
const (
	commentStart = "/**"
	commentEnd   = "*/"
	commentKey   = "lc="
)

// splitWord separates the parts of a logical table name; GetActualTableName
// inserts the lot code at its first occurrence.
const splitWord = "_"

// Lower-case SQL keywords searched for in SQL normalized by trimSql.
const (
	from   = " from "
	join   = " join "
	into   = " into "
	where  = " where "
	update = "update "
)

// Keywords, per statement type, whose following word is collected as a table
// name by the Rewrite*Stmt functions.
var (
	SelectKeywords = [...]string{from, join}
	InsertKeywords = [...]string{into, from}
	UpdateKeywords = [...]string{update}
	DeleteKeywords = [...]string{from}
)

// RewriteDeleteStmt rewrites the logical table names in a DELETE statement to
// their lot-specific physical names.
//
// Table names are collected from the word following each keyword in
// DeleteKeywords (i.e. after "from"). It returns the rewritten SQL and the lot
// code that was used; when no lot code can be found, logicSql is returned
// unchanged together with "".
func RewriteDeleteStmt(logicSql string) (string, string) {
	normalized := trimSql(logicSql)
	tables := hashset.New()
	for i := range DeleteKeywords {
		getTableNameAfterKeyword(normalized, DeleteKeywords[i], tables)
	}
	return getActualSql(logicSql, normalized, tables)
}

// RewriteUpdateStmt rewrites the logical table names in an UPDATE statement
// to their lot-specific physical names.
//
// The table name is taken from the word following each keyword in
// UpdateKeywords (i.e. after "update"). The rewritten SQL and the lot code are
// returned; without a lot code, logicSql comes back unchanged with "".
func RewriteUpdateStmt(logicSql string) (string, string) {
	var (
		lowered = trimSql(logicSql)
		found   = hashset.New()
	)
	for _, kw := range UpdateKeywords {
		getTableNameAfterKeyword(lowered, kw, found)
	}
	return getActualSql(logicSql, lowered, found)
}

func RewriteSelectStmt(logicSql string) (string, string) {
	lowerSql := trimSql(logicSql)
	tableList := hashset.New()
	//根据from分析，from后面必然跟着一个表明

	for _, keyword := range SelectKeywords {
		getTableNameAfterKeyword(lowerSql, keyword, tableList)
	}

	return getActualSql(logicSql, lowerSql, tableList)
}

// sqlWhitespaceReplacer folds carriage returns, line feeds and tabs into plain
// spaces. "\r" is included so CRLF-terminated SQL is normalized as well.
var sqlWhitespaceReplacer = strings.NewReplacer("\r", " ", "\n", " ", "\t", " ")

// trimSql normalizes logicSql for keyword scanning. It trims surrounding
// whitespace, lower-cases the text, and turns every CR, LF and tab into a space,
// so that space-padded keywords such as " from " match across line breaks.
//
// Previously only "\n" and "\t" were replaced. A CRLF right after a keyword
// ("from\r\nuser_info") then left "from\r user_info", and the table was never
// found.
//
// In this file the result is only used to locate table names and the lot
// code; the rewrite itself is applied to the caller's original SQL.
func trimSql(logicSql string) string {
	return sqlWhitespaceReplacer.Replace(strings.ToLower(strings.TrimSpace(logicSql)))
}

// RewriteInsertStmt rewrites the logical table names in an INSERT statement to
// their lot-specific physical names.
//
// Table names are collected from the words following each keyword in
// InsertKeywords ("into", plus "from" for INSERT ... SELECT). The rewritten SQL
// and the lot code are returned; with no lot code, logicSql is returned
// unchanged with "".
func RewriteInsertStmt(logicSql string) (string, string) {
	lower, names := trimSql(logicSql), hashset.New()
	for k := 0; k < len(InsertKeywords); k++ {
		getTableNameAfterKeyword(lower, InsertKeywords[k], names)
	}
	return getActualSql(logicSql, lower, names)
}

// getActualSql replaces every logical table name in list with its physical,
// lot-specific name (see GetActualTableName) inside logicSql.
//
// logicSql is the caller's original statement. lowerSql is its normalized form
// from trimSql, and is used here only to look up the lot code. When no lot code
// is found, logicSql is returned unchanged together with "". Otherwise newlines
// and tabs in logicSql are turned into spaces, the table names are substituted,
// and the rewritten SQL is returned together with the lot code.
func getActualSql(logicSql string, lowerSql string, list *hashset.Set) (string, string) {
	lotCode := GetLotCodeFromComment(lowerSql)
	if len(lotCode) == 0 {
		return logicSql, ""
	}

	sql := strings.ReplaceAll(strings.ReplaceAll(logicSql, "\n", " "), "\t", " ")
	for _, table := range list.Values() {
		logicTableName := table.(string)
		// Only occurrences preceded by a space are rewritten. GetActualTableName
		// returns the new name with that leading space included.
		// NOTE(review): this is a substring match, so a longer identifier that
		// starts with the same text is rewritten too.
		sql = strings.ReplaceAll(sql, " "+logicTableName, GetActualTableName(logicTableName, lotCode))
	}
	// log4go uses a string first argument as the format for the remaining
	// arguments (log4go convention; assumed to hold for this local copy).
	// Without a verb, the SQL was appended as "%!(EXTRA string=...)".
	log4go.Info("sql：%s", sql)
	return sql, lotCode
}

// lotCodeRegexpStr matches a condition such as lotcodeassist = 2007 or
// lotcodeassist='2007'. It is written in lower case because it is matched
// against lower-cased SQL.
const lotCodeRegexpStr = `lotcodeassist\s*=\s*'?\d+'?`

// lotCodeRegexp is lotCodeRegexpStr compiled once at package initialization.
var lotCodeRegexp = regexp.MustCompile(lotCodeRegexpStr)

// GetLotCodeFromComment extracts the lot code (the parking-lot ID used for
// routing) from sql.
//
// It first looks at the first /** ... */ comment in sql, e.g. /** lc=2007*/,
// and takes everything after "lc=" up to the closing "*/", trimmed of
// whitespace. If that yields nothing or the literal "null", it falls back to
// the first match of lotCodeRegexp (e.g. lotcodeassist = '2007') and returns
// the matched digits. It returns "" when neither source provides a value.
//
// Both patterns are lower-case, so sql is expected to be lower-cased already;
// getActualSql passes the output of trimSql.
func GetLotCodeFromComment(sql string) string {
	var lotCode string

	if start := strings.Index(sql, commentStart); start > -1 {
		comment := sql[start+len(commentStart):]
		// Ignore an unterminated comment instead of panicking on a -1 slice
		// bound, and let the regexp fallback below decide. This includes
		// "/**/", whose "*/" overlaps the opening "/**".
		if end := strings.Index(comment, commentEnd); end > -1 {
			comment = comment[:end]
			if key := strings.Index(comment, commentKey); key > -1 {
				lotCode = strings.TrimSpace(comment[key+len(commentKey):])
			}
		}
	}

	if len(lotCode) == 0 || lotCode == "null" {
		// "null" means "no lot code". Clear it so that it is not returned (and
		// used to build table names such as user_null_info) when the regexp
		// fallback also finds nothing.
		lotCode = ""
		log4go.Debug("无法从注解中获取车场路由，使用正则匹配")

		if match := lotCodeRegexp.FindString(sql); match != "" {
			// The pattern guarantees an "="; strip it, the quotes and the spaces.
			value := match[strings.Index(match, "=")+1:]
			lotCode = strings.TrimSpace(strings.ReplaceAll(value, "'", ""))
		}
	}

	return lotCode
}

// GetActualTableName maps a logical table name to its physical name for the
// given lot code. It inserts splitWord+lotCode just before the first splitWord
// ("_") of the name, so ("user_info", "2007") yields " user_2007_info".
//
// A name that has no "_", or whose first "_" is its first character, is left
// unchanged. The result always carries one leading space, which matches the
// " "+logicTableName pattern that getActualSql replaces.
func GetActualTableName(logicTableName string, lotCode string) string {
	pos := strings.Index(logicTableName, splitWord)
	if pos <= 0 {
		return " " + logicTableName
	}
	head, tail := logicTableName[:pos], logicTableName[pos:]
	return " " + head + splitWord + lotCode + tail
}

// getTableNameAfterKeyword scans lowerSql for every occurrence of keyword and
// adds the word following each occurrence to tableList.
//
// lowerSql is expected to be normalized by trimSql (lower-cased, single line).
// A following word that starts with "(" is treated as a subquery: it is
// skipped and the scan continues inside it. For the from keyword, an implicit
// comma join before the next " where "/" join " is expanded by
// addCommaSeparatedTables instead of adding the single following word.
func getTableNameAfterKeyword(lowerSql string, keyword string, tableList *hashset.Set) {
	sql := lowerSql
	for {
		index := strings.Index(sql, keyword)
		if index < 0 {
			return
		}
		// Cut to the text after the keyword:
		// "select * from a,b where 1=1" -> "a,b where 1=1".
		sql = strings.TrimSpace(sql[index+len(keyword):])
		if sql == "" {
			return
		}
		// The table name runs up to the next space, or to the end of the
		// statement when it is the last word. Previously, a statement such as
		// "select * from user_info" panicked on a -1 slice bound.
		end := strings.Index(sql, " ")
		if end < 0 {
			end = len(sql)
		}
		tableName := sql[:end]
		if strings.HasPrefix(tableName, "(") {
			// Subquery: skip it and keep scanning inside.
			continue
		}

		if keyword == from && addCommaSeparatedTables(sql, tableList) {
			continue
		}
		tableList.Add(tableName)
	}
}

// addCommaSeparatedTables handles an implicit join such as
// "a_x, b_y where ...". sql is the text that follows a from keyword.
//
// If the text before the first " where " or " join " in sql contains a comma,
// the function adds every word in that span that contains "_" but does not
// start with it, and reports true. Otherwise it adds nothing and reports false.
func addCommaSeparatedTables(sql string, tableList *hashset.Set) bool {
	whereIndex := strings.Index(sql, where)
	joinIndex := strings.Index(sql, join)
	// Use the earlier boundary, or the one that exists if only one does.
	subIndex := mathutil.Min(whereIndex, joinIndex)
	if subIndex < 0 {
		subIndex = mathutil.Max(whereIndex, joinIndex)
	}
	if subIndex <= 0 {
		return false
	}

	innerJoinTables := sql[:subIndex]
	if strings.Index(innerJoinTables, ",") <= 0 {
		return false
	}
	for _, table := range strings.Split(innerJoinTables, ",") {
		// Each entry may carry an alias ("a_x t1"); keep only words that look
		// like table names (containing "_" after the first character).
		for _, item := range strings.Split(strings.TrimSpace(table), " ") {
			if strings.Index(item, "_") > 0 {
				tableList.Add(strings.TrimSpace(item))
			}
		}
	}
	return true
}
